import os
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt
import pickle
from time import time
from logging import getLogger
import torch.nn as nn
from causally.utils.utils import ensure_dir, get_local_time, early_stopping
from causally.trainer.AbstractTrainer import AbstractTrainer

class SKTrainer(AbstractTrainer):
    """Trainer that drives a model through a single-call fit and a
    treated/control evaluation.

    ``fit`` hands the complete training arrays to the model in one call
    (there is no epoch or batch loop in this class), and ``evaluate``
    compares the model's predictions with the difference between treated
    and control outcomes.
    """

    def __init__(self, config, model):
        """Store the configuration values and set up the checkpoint path.

        Args:
            config: mapping-like object indexed with the keys read below
                ('optimizer', 'learning_rate', 'epochs', ...).
            model: the model object; this class calls ``calculate_loss``
                and ``predict`` on ``self.model``.
        """
        super().__init__(config, model)

        self.logger = getLogger()

        # Settings copied from the config. None of the methods defined in
        # this class read them; they are kept as attributes, presumably for
        # parity with the other trainers -- TODO confirm.
        self.learner = config['optimizer']
        self.learning_rate = config['learning_rate']
        self.epochs = config['epochs']
        # The evaluation interval is capped at the number of epochs.
        self.eval_step = min(config['eval_step'], self.epochs)
        self.stopping_step = config['stopping_step']
        self.clip_grad_norm = config['clip_grad_norm']
        self.valid_metric_bigger = config['valid_metric_bigger']
        self.test_batch_size = config['eval_batch_size']
        self.device = config['device']

        # Checkpoint file name: <model>-<dataset>-<local time>.pkl inside
        # the configured checkpoint directory.
        # NOTE(review): self.config and self.model are not assigned in this
        # class; presumably AbstractTrainer.__init__ sets them -- confirm.
        self.checkpoint_dir = config['checkpoint_dir']
        ensure_dir(self.checkpoint_dir)
        saved_model_file = '{}-{}-{}.pkl'.format(
            self.config['model'], self.config['dataset'], get_local_time())
        self.saved_model_file = os.path.join(self.checkpoint_dir,
                                             saved_model_file)

        # Bookkeeping fields; only initialised here.
        self.start_epoch = 0
        self.cur_step = 0
        self.best_valid_score = -1000000
        self.best_valid_result = None

    def fit(self, train_data=None, valid_data=None):
        """Pass the whole training set to the model in one call.

        Args:
            train_data: object whose ``get_data()`` returns four values,
                unpacked here as ``x, t, y, w``. Although the default is
                ``None``, it must be supplied because ``get_data`` is
                called on it unconditionally.
            valid_data: accepted for interface compatibility; not used.
        """
        x, t, y, w = train_data.get_data()
        # The model's calculate_loss receives the full arrays; its return
        # value is not used here.
        self.model.calculate_loss(x, t, y, w)

    def criterion(self, x, y):
        """Return the mean squared difference between ``x`` and ``y``.

        When the two inputs hold the same number of elements but have
        different shapes (for example ``(n,)`` against ``(n, 1)``), ``y`` is
        reshaped to the shape of ``x`` first. Without this, NumPy would
        broadcast the subtraction to an ``(n, n)`` matrix and the mean would
        be taken over all pairwise differences instead of the element-wise
        ones.

        Args:
            x: array-like of predictions.
            y: array-like of targets.

        Returns:
            The mean of the squared element-wise differences.
        """
        preds = np.asarray(x)
        targets = np.asarray(y)
        if preds.shape != targets.shape and preds.size == targets.size:
            targets = targets.reshape(preds.shape)
        return np.mean(np.square(preds - targets))

    def evaluate(self, treated_data=None, control_data=None):
        """Score the model's predictions against treated-minus-control
        outcomes.

        Args:
            treated_data: object whose ``get_data()`` returns four values;
                the first three are used as ``x``, ``t`` and ``y``.
            control_data: object whose ``get_data()`` returns four values;
                the second and third are used as ``t`` and ``y``.
            Both arguments default to ``None`` but must be supplied, since
            ``get_data`` is called on each unconditionally.

        Returns:
            dict with two entries:
                'pehe': square root of ``criterion(preds, true_ite)``.
                'ate': absolute difference between the mean prediction and
                    the mean of ``true_ite``.
        """
        treated_x, treated_t, treated_y, _ = treated_data.get_data()
        _, control_t, control_y, _ = control_data.get_data()

        # Reference effect: outcome under treatment minus outcome under
        # control.
        true_ite = treated_y - control_y

        preds = self.model.predict(treated_x, control_t, treated_t)

        pehe2 = self.criterion(preds, true_ite)
        ate = np.abs(np.mean(preds) - np.mean(true_ite))
        pehe = np.sqrt(pehe2)

        return {'pehe': pehe, 'ate': ate}
